#!/usr/bin/env python

import argparse
import logging
import os
import re
import subprocess
import sys
import time
from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from threading import Lock

import config
from tools.utils import CountDownLatch

try:
    import cPickle as pickle
except ImportError:
    import pickle

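# Make the parent directory of this file's directory importable.
# NOTE(review): this runs after `import config` / `from tools.utils import ...`
# above, so it cannot help resolve those imports -- confirm the intended order.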
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

"""
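Introduction: from the given git repo directory, search the log for commits that
fix CVE vulnerabilities using a keyword (keyword=CVE-20), then run `git show` on
each such commit to extract the concrete diff of the fix.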
"""

# Splits `git log` output into one chunk per commit: matches a newline that is
# immediately followed by "commit <40 word chars>" and then an "Author: " or
# "Merge: " line. The lookahead keeps the "commit ..." header in the next chunk;
# only the newline itself is consumed.
PATTERN_COMMIT = re.compile(r'[\n](?=commit\s\w{40}\nAuthor:\s)|[\n](?=commit\s\w{40}\nMerge:\s)')


class InfoStruct:
    """Container for the settings and state shared across this script.

    The class-level values below are defaults; ``__init__`` overwrites the
    directory fields and the CVE dictionary per instance, and other code in
    this file assigns to the remaining fields after argument parsing.
    """

    # Defaults taken from the project ``config`` module.
    GitBinary = config.GIT_BIN
    GitStoragePath = config.GIT_STORAGE_PATH
    keyword = config.CVE_KEYWORD
    DebugMode = config.DEBUG_MODE

    # Repository selection.
    RepoName = ''  # repository name
    MultimodeFlag = 0
    MultiRepoList = []

    # Directory layout (both set per instance in __init__).
    OriginalDir = ''  # vuddy root directory
    DiffDir = ''

    # CVE metadata and an optional manually supplied CVE id.
    CveDict = {}
    cveID = None

    def __init__(self, originalDir, CveDataPath):
        """
        Record the directory layout and load the pickled CVE dictionary.
        :param originalDir: root directory; diffs go to ``<originalDir>/diff``
        :param CveDataPath: path of the pickle file holding the CVE dictionary
        """
        self.OriginalDir = originalDir
        self.DiffDir = os.path.join(originalDir, 'diff')
        # NOTE: unpickling can execute arbitrary code, so this file must come
        # from a trusted source.
        with open(CveDataPath, "rb") as cveFile:
            self.CveDict = pickle.load(cveFile)


""" GLOBALS """
originalDir = config.ROOT_PATH  # vuddy root directory
# Pickled CVE dictionary, expected at <root>/data/cvedata.pkl.
cveDataPath = os.path.join(originalDir, "data", "cvedata.pkl")
# Built at import time; the constructor opens and unpickles cveDataPath.
info = InfoStruct(originalDir, cveDataPath)
# Held around the log calls that report a written diff in parallel_process.
printLock = Lock()

""" FUNCTIONS """


def parse_argument():
    """
    Parse the command line and store the result on the global ``info``.

    Sets ``info.RepoName``, ``info.keyword``, ``info.cveID``,
    ``info.MultimodeFlag`` and ``info.MultiRepoList``; ``info.DebugMode`` is
    switched on when the hidden ``-d`` flag is given.

    In multi-repo mode (``-m``) the sub-repository list is read
      * on Windows from ``data/repolists/list_<REPO>`` under the root
        directory (lines longer than two characters, right-stripped);
      * elsewhere by running ``find`` below ``<GitStoragePath>/<REPO>`` for
        directories that contain a ``.git`` entry.
    :return: None
    """
    parser = argparse.ArgumentParser(prog='get_cvepatch_from_git.py')
    parser.add_argument('REPO',
                        help='''Repository name''')
    parser.add_argument('-m', '--multimode', action="store_true",
                        help='''Turn on Multimode''')
    parser.add_argument('-k', '--keyword',
                        help="Keyword to GREP, default: CVE-20", default="CVE-20")
    parser.add_argument('-c', '--cveid', help="CVE id to assign (Only when doing manual keyword search)")
    parser.add_argument('-d', '--debug', action="store_true", help=argparse.SUPPRESS)  # Hidden Debug Mode

    args = parser.parse_args()

    info.RepoName = args.REPO
    info.keyword = args.keyword
    info.cveID = args.cveid
    info.MultimodeFlag = 0
    info.MultiRepoList = []
    if args.multimode:
        info.MultimodeFlag = 1
        if config.os_is_windows():
            repoListPath = os.path.join(originalDir, 'data', 'repolists', 'list_' + info.RepoName)
            with open(repoListPath) as fp:
                info.MultiRepoList = [repoLine.rstrip() for repoLine in fp if len(repoLine) > 2]
        else:
            repoBaseDir = os.path.join(info.GitStoragePath, info.RepoName)
            # Print every directory that holds a ".git" entry and do not
            # descend into it. The original command string carried
            # "-logging.info(-prune" where "-print -prune" was meant.
            # An argument list avoids shell quoting problems with the path.
            command_find = ["find", repoBaseDir, "-type", "d",
                            "-exec", "test", "-e", "{}/.git", ";",
                            "-print", "-prune"]
            findOutput = subprocess.check_output(command_find).decode(encoding=config.ENCODING)
            repoNames = findOutput.replace(repoBaseDir + "/", "").rstrip().split("\n")
            # Empty find output would otherwise yield a single '' entry.
            info.MultiRepoList = [repoName for repoName in repoNames if repoName]
    if args.debug:
        info.DebugMode = True


def init():
    """
    Parse the arguments, log the run configuration and make sure the output
    directory ``<DiffDir>/<RepoName>`` exists.
    :return: None
    """
    parse_argument()

    logging.info("Retrieving CVE patch from %s", info.RepoName)
    logging.info("Multi-repo mode: %s", "ON." if info.MultimodeFlag else "OFF.")
    logging.info("Initializing...")

    repoDiffDir = os.path.join(info.DiffDir, info.RepoName)
    try:
        # exist_ok covers the expected "already there" case without hiding
        # every other OSError.
        os.makedirs(repoDiffDir, exist_ok=True)
    except OSError:
        # Stay best-effort like before, but leave a trace of the failure.
        logging.exception("Could not create output directory %s", repoDiffDir)

    logging.info("Done.")


def callGitLog(gitDir):
    """
    Collect CVE commit log from repository
    :param gitDir: repository path
    :return: list of per-commit log chunks (``git log`` output split with
             PATTERN_COMMIT); an empty list when git fails or its output
             cannot be decoded
    """
    command_log = "\"{0}\" --no-pager log --all --pretty=fuller --grep=\"{1}\"".format(info.GitBinary, info.keyword)
    logging.info(gitDir)
    # The working directory is changed on purpose: callGitShow passes no cwd,
    # so in single-repo mode it runs wherever this call leaves the process.
    os.chdir(gitDir)
    try:
        gitLogOutput = subprocess.check_output(command_log, shell=True).decode(config.ENCODING)
    except subprocess.CalledProcessError as e:
        # The message needs a placeholder for the extra argument; without it
        # the logging module fails to format the record.
        logging.exception("[-] Git log error: %s", e)
        return []
    except UnicodeDecodeError as err:
        logging.exception("[-] Unicode error: %s", err)
        return []

    return PATTERN_COMMIT.split(gitLogOutput)


# One pattern per filtered keyword; each matches the keyword (or its plural
# with a trailing "s") only when a non-word character stands on both sides.
PATTERN_FILTER_COMMIT = [
    re.compile(r"\W{0}\W|\W{0}s\W".format(word))
    for word in ("merge", "revert", "upgrade")
]


def filterCommitMessage(commitMessage):
    """
    Flag commits that are likely false positives.
    A commit is flagged when its lower-cased message matches any of the
    'merge', 'revert' or 'upgrade' patterns.
    :param commitMessage: commit message
    :return: 1 when the commit should be filtered out, otherwise 0
    """
    lowered = commitMessage.lower()
    return int(any(pattern.search(lowered) for pattern in PATTERN_FILTER_COMMIT))


def callGitShow(gitBinary, commitHashValue):
    """
    Grep data of git show
    :param gitBinary: git executable (name or path); it is quoted in the
                      shell command
    :param commitHashValue: commit to show
    :return: raw (undecoded) output of ``git show --pretty=fuller``; the
             empty string ``''`` when the command exits with an error
    """
    command_show = "\"{0}\" show --pretty=fuller {1}".format(gitBinary, commitHashValue)

    # No cwd is passed, so git runs in the current working directory of the
    # process.
    try:
        return subprocess.check_output(command_show, shell=True)
    except subprocess.CalledProcessError as e:
        # The message needs a placeholder for the extra argument; without it
        # the logging module fails to format the record.
        logging.exception("error: %s", e)
        return ''


def updateCveInfo(cveDict, cveId):
    """
    Get CVSS score and CWE id from CVE id
    :param cveDict: mapping from CVE id to a sequence whose item 0 is the
                    CVSS score and item 1 the CWE id string
    :param cveId: CVE id to look up
    :return: ``"<cveId>_<cvss>_CWE-<nnn>_"``; the score falls back to "0.0"
             and the CWE to "CWE-000" when the entry is missing, empty or
             malformed
    """
    lookupErrors = (KeyError, IndexError, TypeError)

    try:
        cvss = str(cveDict[cveId][0])
    except lookupErrors:
        cvss = "0.0"
    if not cvss:
        cvss = "0.0"

    try:
        cwe = cveDict[cveId][1]
    except lookupErrors:
        cwe = ""

    # The number is the second dash-separated field, zero-padded to three
    # characters. Values that are not strings or carry no dash used to raise;
    # they now fall back to the placeholder id.
    cweParts = cwe.split('-') if isinstance(cwe, str) else []
    if len(cweParts) > 1:
        cwe = "CWE-" + cweParts[1].zfill(3)
    else:
        cwe = "CWE-000"

    return cveId + '_' + cvss + '_' + cwe + '_'


def process(commitsList, subRepoName):
    """
    Run parallel_process on every commit chunk of one repository.

    In debug mode the chunks are handled one after another in this thread;
    otherwise they are handed to a thread pool and a CountDownLatch is used
    to wait until every chunk has been handled.
    :param commitsList: per-commit log chunks as returned by callGitLog
    :param subRepoName: sub-repository name in multi-repo mode (the process
                        changes into that repository), or None
    :return: None
    """
    # An empty list and a list whose first chunk is '' both mean "nothing to
    # do"; returning early avoids building a zero-count latch and a pool.
    noCommits = not commitsList or commitsList[0] == ''
    if noCommits:
        logging.info("No commit in %s", info.RepoName)
    else:
        logging.info("%s commits in %s", len(commitsList), info.RepoName)

    if subRepoName is None:
        logging.info("\n")
    else:
        logging.info(subRepoName)
        # git show is run without an explicit cwd, so move into the
        # sub-repository first.
        os.chdir(os.path.join(info.GitStoragePath, info.RepoName, subRepoName))

    if noCommits:
        return

    if info.DebugMode:
        for commitMessage in commitsList:
            parallel_process(subRepoName, commitMessage)
        return

    # use thread pool
    cdl = CountDownLatch(count=len(commitsList))
    parallel_partial = partial(parallel_process, subRepoName, cdl=cdl)
    with ThreadPoolExecutor() as pool:
        # The iterator returned by map is not consumed, so exceptions raised
        # in workers are not re-raised here.
        pool.map(parallel_partial, commitsList)
        cdl.wait()


# CVE ids of the form CVE-20YY-N with a 4 to 7 digit sequence number.
# Raw string: "\d" in a plain literal is an invalid escape sequence.
PATTERN_CVE = re.compile(r'CVE-20\d{2}-\d{4,7}')


def count_down(func):
    """
    Decorator that counts down a latch once the wrapped call has finished.

    The latch is taken from the ``cdl`` keyword argument of the call; when
    that argument is absent or None nothing is counted down.
    :param func: function to wrap
    :return: wrapper with the same call signature and return value as func
    """
    from functools import wraps  # only `partial` is imported at file level

    @wraps(func)
    def wrapper(*args, **kwargs):
        cdl = kwargs.get('cdl')
        try:
            return func(*args, **kwargs)
        finally:
            # Count down even when func raises; otherwise whoever waits on
            # the latch would block forever.
            if cdl is not None:
                cdl.count_down()
    return wrapper


@count_down
def parallel_process(subRepoName, commitMessage, cdl: CountDownLatch = None):
    """
    Save the ``git show`` output of one commit as a ``.diff`` file.

    Commits whose message is rejected by filterCommitMessage are skipped, as
    are commits without a CVE id when no id was supplied on the command line.
    :param subRepoName: sub-repository name in multi-repo mode (written as
                        the first line of the diff file), or None
    :param commitMessage: one chunk of ``git log`` output; characters 7..46
                          are taken as the commit hash
    :param cdl: not used in the body; the count_down decorator reads it from
                the keyword arguments
    :return: None
    """
    if filterCommitMessage(commitMessage):
        return

    commitHashValue = commitMessage[7:47]

    # note: CVE id can now be 7 digit numbers
    # Sorted so that the chosen id and the dependency line do not depend on
    # set iteration order.
    cveIdList = sorted(set(PATTERN_CVE.findall(commitMessage)))

    # Note: Aug 5, 2016
    # If multiple CVE ids are assigned to one commit, store the dependency in
    # a file which is named after the repo (e.g., ~/diff/dependency_ubuntu)
    # and use one representative CVE that has the smallest ID number for the
    # filename.
    # A sample:
    # CVE-2014-6416_2e9466c84e5beee964e1898dd1f37c3509fa8853    CVE-2014-6418_CVE-2014-6417_CVE-2014-6416_
    if len(cveIdList) > 1:  # do this only if multiple CVEs are assigned to a commit
        # "Smallest" compares only the sequence number after the year.
        minCve = min(cveIdList, key=lambda cveId: int(cveId.split('-')[2]))
        cveIdFull = ''.join(cveId + '_' for cveId in cveIdList)
        dependency = os.path.join(info.DiffDir, "dependency_" + info.RepoName)
        with open(dependency, "a") as fp:
            fp.write(minCve + '_' + commitHashValue + '\t' + cveIdFull + '\n')
    elif cveIdList:
        minCve = cveIdList[0]
    elif info.cveID is not None:
        minCve = info.cveID  # when CVE ID is given manually through command line argument
    else:
        return

    gitShowOutput = callGitShow(info.GitBinary, commitHashValue)
    if not gitShowOutput:
        # callGitShow returns '' and logs the error when git show fails.
        return

    try:
        gitShowOutput = gitShowOutput.decode(encoding=config.ENCODING)
    except (UnicodeDecodeError, LookupError) as e:
        logging.exception("[+] Parse Error: %s", e)
        return

    finalFileName = updateCveInfo(info.CveDict, minCve)

    diffFileName = "{0}{1}.diff".format(finalFileName, commitHashValue)
    diffPath = os.path.join(info.DiffDir, info.RepoName, diffFileName)
    try:
        with open(diffPath, "w", encoding=config.ENCODING) as fp:
            if subRepoName is not None:  # multi-repo mode
                fp.write(subRepoName + '\n')
            fp.write(gitShowOutput)
    except IOError as e:
        with printLock:
            # The placeholders are required; an extra argument without one
            # makes the logging module fail to format the record.
            logging.exception("[+] Writing %s Error: %s", diffFileName, e)
    else:
        with printLock:
            logging.info("[+] Writing %s Done.", diffFileName)


def main():
    """
    Script entry point: initialise, collect the matching commits of the
    repository (or of every sub-repository in multi-repo mode), write their
    diffs, and log a summary with the elapsed time.
    :return: None
    """
    t1 = time.time()
    init()
    if info.MultimodeFlag:
        repoCount = len(info.MultiRepoList)
        for sidx, subRepoName in enumerate(info.MultiRepoList, start=1):
            gitDir = os.path.join(info.GitStoragePath, info.RepoName, subRepoName)  # where .git exists
            commitsList = callGitLog(gitDir)
            # Progress as "i/n"; built explicitly rather than with
            # os.path.join, whose separator is platform dependent.
            logging.info("%d/%d", sidx, repoCount)
            if commitsList:
                process(commitsList, subRepoName)
    else:
        gitDir = os.path.join(info.GitStoragePath, info.RepoName)  # where .git exists
        commitsList = callGitLog(gitDir)
        process(commitsList, None)

    repoDiffDir = os.path.join(info.DiffDir, info.RepoName)
    # Counts every entry in the output directory, not only files written by
    # this run.
    logging.info("%s patches saved in %s", len(os.listdir(repoDiffDir)), repoDiffDir)
    logging.info("Done. (%ssec)", time.time() - t1)


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    # conf_log comes from the project config module; presumably it configures
    # the logging output used throughout this script -- verify in config.
    config.conf_log()
    main()
